package org.apache.spark.ml.feature

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{ParamMap, StringArrayParam}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, MLWritable}
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset}

import scala.collection.mutable.ArrayBuffer

/**
 * A [[Transformer]] that casts a configurable set of DataFrame columns to new
 * Spark SQL data types while leaving all other columns untouched.
 *
 * The column-to-type mapping is stored as two parallel string-array params
 * (types serialized as `DataType` JSON) so the transformer can be persisted
 * with the default params writer/reader.
 */
class ColumnCastType(override val uid: String) extends Transformer with MLWritable with DefaultParamsWritable {

  /** Names of the columns to cast; kept parallel to `castColumnsType`. */
  final val castColumnsName: StringArrayParam =
    new StringArrayParam(this, "castColumnsName", "names of the columns to cast")

  /** Target types as DataType JSON strings; kept parallel to `castColumnsName`. */
  final val castColumnsType: StringArrayParam =
    new StringArrayParam(this, "castColumnsType", "target data types (DataType JSON) for the columns to cast")

  // Default to "no casts" so transform()/getCastColumns are safe no-ops when
  // setCastColumns was never called, instead of $(param) throwing
  // NoSuchElementException for an unset param.
  setDefault(castColumnsName -> Array.empty[String], castColumnsType -> Array.empty[String])

  def this() = this(Identifiable.randomUID("ColumnTypeCast"))

  /**
   * Sets the full column-name -> target-type mapping in one call, replacing
   * any previously configured casts.
   */
  def setCastColumns(value: Map[String, DataType]): this.type = {
    // Serialize the DataType as JSON so the param survives ML persistence.
    val (names, types) = value.toSeq.map { case (name, dt) => (name, dt.json) }.unzip
    set(castColumnsName, names.toArray)
    set(castColumnsType, types.toArray)
    this
  }

  /** Returns the configured column names (empty when no casts are set). */
  def getCastColumnsName(): Array[String] = $(castColumnsName)

  /**
   * Reconstructs the column-name -> [[DataType]] map from the two parallel
   * params. Returns an empty map when no casts are configured.
   */
  def getCastColumns: Map[String, DataType] = {
    val names = $(castColumnsName)
    val types = $(castColumnsType)
    // The params can be set independently; fail fast with a clear message
    // rather than an IndexOutOfBoundsException during lookup.
    require(names.length == types.length,
      s"castColumnsName (${names.length}) and castColumnsType (${types.length}) must have the same length")
    names.zip(types.map(DataType.fromJson)).toMap
  }

  /**
   * Selects every column of the input dataset, casting the configured columns
   * to their target types. Column order and non-cast columns are preserved.
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    val targetSchema = transformSchema(dataset.schema)
    dataset.select(targetSchema.map(field => dataset.col(field.name).cast(field.dataType)): _*)
  }

  /**
   * Produces the output schema: each configured column takes its target type
   * (nullability and metadata preserved); all other fields pass through.
   */
  override def transformSchema(schema: StructType): StructType = {
    val castColumns = getCastColumns
    if (castColumns.isEmpty) {
      StructType(schema)
    } else {
      StructType(schema.map { field =>
        castColumns.get(field.name) match {
          // Only rewrite the field when the requested type actually differs.
          case Some(target) if !field.dataType.typeName.equals(target.typeName) =>
            StructField(field.name, target, field.nullable, field.metadata)
          case _ =>
            field
        }
      })
    }
  }

  override def copy(extra: ParamMap): ColumnCastType = defaultCopy(extra)
}

/**
 * Companion object providing ML persistence support for [[ColumnCastType]].
 *
 * `load` simply delegates to [[DefaultParamsReadable]]; the explicit override
 * pins the concrete return type to `ColumnCastType` for callers.
 */
object ColumnCastType extends DefaultParamsReadable[ColumnCastType] {

  /** Loads a previously saved [[ColumnCastType]] instance from `path`. */
  override def load(path: String): ColumnCastType = super.load(path)

}